%==========================================================================
% This script compares multi-spectral imagery (MSI) noise removal methods
% listed as follows:
%   1. band-wise K-SVD
%   2. band-wise BM3D
%   3. Integral K-SVD
%   4. 3D NLM
%   5. BM4D
%   6. LRTA
%   7. PARAFAC
%   8. tensor dictionary learning
%
% Five quality assessment (QA) indices -- PSNR, SSIM, FSIM, ERGAS and SAM
% -- are calculated for the noisy input and for each method after denoising.
%
% See also TensorDL, ksvddenoise, bm3d, NLM3D, bm4d, LRTA, PARAFAC and MSIQA
%
% by Yi Peng
%==========================================================================

clear; clc;

% Put the parent directory (and all of its subfolders) on the MATLAB path;
% presumably this is where the denoisers and QA helpers used below
% (ksvddenoise, BM3D, NLM3D, bm4d, LRTA, PARAFAC, TensorDL, MSIQA,
% GenAnscombe_*) live.
% The parent is resolved from this script's own location, so the demo no
% longer depends on the current folder. The former literal '..\' also only
% worked on Windows. When mfilename returns '' (e.g. code pasted into the
% Command Window), fall back to '..' relative to the current folder, which
% matches the original behaviour.
script_dir = fileparts(mfilename('fullpath'));
if isempty(script_dir)
    parent_dir = '..';
else
    parent_dir = fileparts(script_dir);
end
addpath(genpath(parent_dir));

%% Set enable bits
% 1 = run the method, 0 = skip it (its clean_msi_* result is then saved as
% an empty array). Tensor dictionary learning has no switch and always runs.
EN_BWKSVD  = 1;   % band-wise K-SVD
EN_BWBM3D  = 1;   % band-wise BM3D
EN_KSVD    = 1;   % integral K-SVD
EN_NLM3D   = 1;   % 3D NLM
EN_BM4D    = 1;   % BM4D
EN_LRTA    = 1;   % LRTA
EN_PARAFAC = 1;   % PARAFAC

%% Set noise level
% The clean MSI is scaled by peak = 2^kappa before Poisson sampling, then
% Gaussian noise with standard deviation peak * sigma_ratio is added.
kappa       = 4;     % smaller kappa <--> heavier Poisson noise
sigma_ratio = 0.2;   % higher sigma_ratio <--> heavier Gaussian noise

%% Load the MSI
% NOTE: make sure the MSI is of size height x width x nbands and in range [0, 1]
% In this example, the MSI is stored in a variable named 'msi'.
% Only 'msi' is loaded, and into a struct rather than straight into the
% workspace, so other variables stored in the MAT-file cannot silently
% overwrite the settings above (kappa, sigma_ratio, EN_*, ...).
filename = 'pompoms.mat';
loaded = load(filename, 'msi');
if ~isfield(loaded, 'msi')
    error('Variable ''msi'' not found in %s.', filename);
end
msi = loaded.msi;
clear loaded;
[height, width, nbands] = size(msi);
msi_sz = [height, width, nbands];

% One slot per result: 1 = noisy input, 2 = band-wise K-SVD,
% 3 = band-wise BM3D, 4 = integral K-SVD, 5 = 3D NLM, 6 = BM4D, 7 = LRTA,
% 8 = PARAFAC, 9 = tensor dictionary learning.
nresults = 9;
psnr = zeros(nresults, 1);
ssim = zeros(nresults, 1);
fsim = zeros(nresults, 1);
ergas = zeros(nresults, 1);
sam = zeros(nresults, 1);
time = zeros(nresults, 1);

%% Add Poisson and Gaussian noise
% Mixed Poisson-Gaussian model: the clean MSI is scaled to the expected
% peak 2^kappa and Poisson-sampled. Zero-mean Gaussian noise is then added,
% with a standard deviation equal to sigma_ratio times that peak.
peak = 2^kappa;
sigma = peak * sigma_ratio;
photon_counts = poissrnd(msi * peak);                  % Poisson part
noisy_msi = photon_counts + sigma * randn(msi_sz);     % Gaussian part
clear photon_counts;

% QA of the unprocessed noisy input (slot 1), both cubes on a 0-255 scale.
[psnr(1), ssim(1), fsim(1), ergas(1), sam(1)] = MSIQA(msi * 255, noisy_msi / peak * 255);

%% Apply VST
% Variance-stabilize the Poisson-Gaussian noise with the generalized
% Anscombe transform (parameterized by the Gaussian sigma), then min-max
% normalize the result to [0, 1]. VST_sigma = 1 / range is the noise std
% after this rescaling, assuming the VST yields (approximately) unit-variance
% noise. It is the sigma handed to the denoisers below.
VST_msi = GenAnscombe_forward(noisy_msi, sigma);
min_VST_msi = min(VST_msi(:));
max_VST_msi = max(VST_msi(:));
VST_range = max_VST_msi - min_VST_msi;
VST_msi = (VST_msi - min_VST_msi) / VST_range;
VST_sigma = 1 / VST_range;

%% Denoise by band-wise K-SVD
% Each spectral band of the normalized VST cube is denoised independently
% as a 2-D image (8 x 8 blocks), then the VST is undone and QA is computed.
clean_msi_bwksvd = [];
if EN_BWKSVD
    fprintf('Denoising by band-wise K-SVD ...\n');
    bwksvd_params.blocksize = [8, 8];
    bwksvd_params.sigma = VST_sigma;
    bwksvd_params.memusage = 'high';
    bwksvd_params.trainnum = 200;
    bwksvd_params.stepsize = [4, 4];
    bwksvd_params.maxval = 1;
    bwksvd_params.dictsize = 128;
    % Preallocate and loop over the actual band count. The loop was
    % previously hard-coded to 31 bands, which broke on other MSIs.
    clean_msi_bwksvd = zeros(msi_sz);
    tic;
    for ch = 1:nbands
        bwksvd_params.x = VST_msi(:, :, ch);
        clean_msi_bwksvd(:, :, ch) = ksvddenoise(bwksvd_params, 0);
    end
    time(2) = toc;
    % Undo the [0, 1] normalization, then invert the VST.
    clean_msi_bwksvd = clean_msi_bwksvd * (max_VST_msi - min_VST_msi) + min_VST_msi;
    clean_msi_bwksvd = GenAnscombe_inverse_exact_unbiased(clean_msi_bwksvd, sigma);
    % QA on the same 0-255 scale as the noisy input (slot 2).
    [psnr(2), ssim(2), fsim(2), ergas(2), sam(2)] = ...
        MSIQA(msi * 255, clean_msi_bwksvd / peak * 255);
end

%% Denoise by band-wise BM3D
% Each spectral band is denoised independently with BM3D, then the VST is
% undone and QA is computed.
clean_msi_bwbm3d = [];
if EN_BWBM3D
    fprintf('Denoising by band-wise BM3D ...\n');
    % Preallocate and loop over the actual band count. The loop was
    % previously hard-coded to 31 bands, which broke on other MSIs.
    clean_msi_bwbm3d = zeros(msi_sz);
    tic;
    for ch = 1:nbands
        % Unlike the other denoisers, sigma is multiplied by 255 here.
        % Presumably BM3D expects sigma on a 0-255 scale; confirm in BM3D's
        % help. The leading 1 is passed where a reference image would go.
        [~, clean_msi_bwbm3d(:, :, ch)] = BM3D(1, VST_msi(:, :, ch), VST_sigma*255);
    end
    time(3) = toc;
    % Undo the [0, 1] normalization, then invert the VST.
    clean_msi_bwbm3d = clean_msi_bwbm3d * (max_VST_msi - min_VST_msi) + min_VST_msi;
    clean_msi_bwbm3d = GenAnscombe_inverse_exact_unbiased(clean_msi_bwbm3d, sigma);
    % QA on the same 0-255 scale as the noisy input (slot 3).
    [psnr(3), ssim(3), fsim(3), ergas(3), sam(3)] = ...
        MSIQA(msi * 255, clean_msi_bwbm3d / peak * 255);
end

%% Denoise by integral K-SVD
% One K-SVD run over 3-D (8 x 8 x 7) blocks of the whole normalized cube.
clean_msi_ksvd = [];
if EN_KSVD
    fprintf('Denoising by integral K-SVD ...\n');
    ksvd_params = struct( ...
        'blocksize', [8, 8, 7], ...
        'sigma',     VST_sigma, ...
        'memusage',  'high', ...
        'trainnum',  2000, ...
        'stepsize',  [4, 4, 4], ...
        'maxval',    1, ...
        'dictsize',  500, ...
        'x',         VST_msi);
    tic;
    clean_msi_ksvd = ksvddenoise(ksvd_params, 0);
    time(4) = toc;
    % Undo the [0, 1] normalization and invert the VST in one step.
    clean_msi_ksvd = GenAnscombe_inverse_exact_unbiased( ...
        clean_msi_ksvd * (max_VST_msi - min_VST_msi) + min_VST_msi, sigma);
    % QA (slot 4).
    [psnr(4), ssim(4), fsim(4), ergas(4), sam(4)] = MSIQA(msi * 255, clean_msi_ksvd / peak * 255);
end

%% Denoise by NLM3D
% 3-D non-local means on the normalized VST cube. The numeric arguments
% (5, 2, 3, 0) are passed through as-is; see NLM3D's help for their meaning.
clean_msi_nlm3d = [];
if EN_NLM3D
    fprintf('Denoising by 3D NLM ...\n');
    tic;
    clean_msi_nlm3d = NLM3D(VST_msi, 5, 2, 3, 0);
    time(5) = toc;
    % Back to the pre-normalization range, then invert the VST.
    clean_msi_nlm3d = GenAnscombe_inverse_exact_unbiased( ...
        clean_msi_nlm3d * (max_VST_msi - min_VST_msi) + min_VST_msi, sigma);
    % QA (slot 5).
    [psnr(5), ssim(5), fsim(5), ergas(5), sam(5)] = MSIQA(msi * 255, clean_msi_nlm3d / peak * 255);
end

%% Denoise by BM4D
% BM4D on the whole normalized VST cube, with sigma given on the [0, 1]
% scale. The leading 1 is passed where a reference volume would go
% (presumably "none"; confirm in bm4d's help).
clean_msi_bm4d = [];
if EN_BM4D
    fprintf('Denoising by BM4D ...\n');
    tic;
    [~, clean_msi_bm4d] = bm4d(1, VST_msi, VST_sigma);
    time(6) = toc;
    % Back to the pre-normalization range, then invert the VST.
    clean_msi_bm4d = GenAnscombe_inverse_exact_unbiased( ...
        clean_msi_bm4d * (max_VST_msi - min_VST_msi) + min_VST_msi, sigma);
    % QA (slot 6).
    [psnr(6), ssim(6), fsim(6), ergas(6), sam(6)] = MSIQA(msi * 255, clean_msi_bm4d / peak * 255);
end

%% Denoise by LRTA
% The normalized cube is wrapped with tensor() for LRTA, and the result is
% converted back to a plain array with double().
clean_msi_lrta = [];
if EN_LRTA
    fprintf('Denoising by LRTA ...\n');
    tic;
    clean_msi_lrta = double(LRTA(tensor(VST_msi)));
    time(7) = toc;
    % Back to the pre-normalization range, then invert the VST.
    clean_msi_lrta = GenAnscombe_inverse_exact_unbiased( ...
        clean_msi_lrta * (max_VST_msi - min_VST_msi) + min_VST_msi, sigma);
    % QA (slot 7).
    [psnr(7), ssim(7), fsim(7), ergas(7), sam(7)] = MSIQA(msi * 255, clean_msi_lrta / peak * 255);
end

%% Denoise by PARAFAC
% The normalized cube is wrapped with tensor() for PARAFAC. The VST is then
% undone and QA is computed.
clean_msi_parafac = [];
if EN_PARAFAC
    fprintf('Denoising by PARAFAC ...\n');
    tic;
    % Convert the result to a plain double array, exactly as the LRTA section
    % does. This way the rescaling, the inverse VST, MSIQA and the final uint8
    % conversion all operate on ordinary numeric data. double() is a no-op if
    % PARAFAC already returns a double array.
    clean_msi_parafac = double(PARAFAC(tensor(VST_msi)));
    time(8) = toc;
    % Undo the [0, 1] normalization, then invert the VST.
    clean_msi_parafac = clean_msi_parafac * (max_VST_msi - min_VST_msi) + min_VST_msi;
    clean_msi_parafac = GenAnscombe_inverse_exact_unbiased(clean_msi_parafac, sigma);
    % QA on the same 0-255 scale as the noisy input (slot 8).
    [psnr(8), ssim(8), fsim(8), ergas(8), sam(8)] = ...
        MSIQA(msi * 255, clean_msi_parafac / peak * 255);
end

%% Denoise by tensor dictionary learning
% Always runs (there is no enable switch for this method).
fprintf('Denoising by tensor dictionary learning ...\n');
vstbmtf_params = struct('peak_value', 1, 'nsigma', VST_sigma);
tic;
clean_msi_tdl = TensorDL(VST_msi, vstbmtf_params);
time(9) = toc;
% Back to the pre-normalization range, then invert the VST.
clean_msi_tdl = GenAnscombe_inverse_exact_unbiased( ...
    clean_msi_tdl * (max_VST_msi - min_VST_msi) + min_VST_msi, sigma);
% QA (slot 9).
[psnr(9), ssim(9), fsim(9), ergas(9), sam(9)] = MSIQA(msi * 255, clean_msi_tdl / peak * 255);

%% Save the results
% ENABLE_BITS(k) tells whether slot k of the QA/time arrays holds a result:
%   [noisy, BW-KSVD, BW-BM3D, KSVD, NLM3D, BM4D, LRTA, PARAFAC, TDL]
EN_NOISY = 1;   % the noisy input is always assessed
EN_TDL = 1;     % tensor dictionary learning always runs
ENABLE_BITS = [EN_NOISY, EN_BWKSVD, EN_BWBM3D, EN_KSVD, EN_NLM3D, ...
    EN_BM4D, EN_LRTA, EN_PARAFAC, EN_TDL];

% Store every cube as 'uint8' on a 0-255 scale for lighter storage. The
% ground truth is in [0, 1]. The noisy and denoised cubes are on the scale
% of peak, so they are divided by peak first. Skipped methods stay empty.
to_uint8 = @(cube) uint8(cube / peak * 255);
msi = uint8(msi * 255);
noisy_msi         = to_uint8(noisy_msi);
clean_msi_bwksvd  = to_uint8(clean_msi_bwksvd);
clean_msi_bwbm3d  = to_uint8(clean_msi_bwbm3d);
clean_msi_ksvd    = to_uint8(clean_msi_ksvd);
clean_msi_nlm3d   = to_uint8(clean_msi_nlm3d);
clean_msi_bm4d    = to_uint8(clean_msi_bm4d);
clean_msi_lrta    = to_uint8(clean_msi_lrta);
clean_msi_parafac = to_uint8(clean_msi_parafac);
clean_msi_tdl     = to_uint8(clean_msi_tdl);
clear to_uint8;

save('results.mat', 'ENABLE_BITS', 'msi', 'noisy_msi', 'clean_msi_bm4d', ...
    'clean_msi_bwksvd', 'clean_msi_bwbm3d', 'clean_msi_ksvd', 'clean_msi_lrta', ...
    'clean_msi_nlm3d', 'clean_msi_parafac', 'clean_msi_tdl', ...
    'nbands', 'time', 'psnr', 'ssim', 'fsim', 'ergas', 'sam');